---
title: Training loop
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/14_train_ae.ipynb"
---
# Notebook setup: auto-reload edited modules so source changes are picked up
# without restarting the kernel (IPython magics — not plain Python).
%load_ext autoreload
%autoreload 2
# Pin this session to the second physical GPU; CUDA then exposes it as 'cuda:0',
# which is the device string used throughout the rest of this script.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Load the experiment configuration (OmegaConf YAML) for this training run.
cfg = OmegaConf.load('../config/experiment/N2_352_2.yaml')
# cfg =OmegaConf.load(default_conf)
# Build the forward-model components from the config: PSF, camera noise model,
# and microscope (micro is rebuilt on GPU further below).
psf, noise, micro = load_psf_micro_psf_noise(cfg)
# First volume of the TIFF stack — the 3D acquisition used for training.
img_3d = load_tiff_image(cfg.data_path.image_path)[0]
# Background estimator instantiated via Hydra from its config node.
estimate_backg = hydra.utils.instantiate(cfg.bg_estimation)
# ROI mask restricts random crops to informative regions (voxels above the
# configured intensity percentile after pooling).
roi_mask = get_roi_mask(img_3d, tuple(cfg.roi_mask.pool_size), percentile= cfg.roi_mask.percentile)
# Cubic random crops of side crop_sz, sampled only inside the ROI mask.
rand_crop = RandomCrop3D((cfg.random_crop.crop_sz,cfg.random_crop.crop_sz,cfg.random_crop.crop_sz), roi_mask)
# Rescales raw intensities from [data_min, data_max] into [low, high];
# presumably produces the emitter-rate/probability map — confirm against ScaleTensor.
probmap_generator = ScaleTensor(low=cfg.prob_generator.low,
                                high=cfg.prob_generator.high,
                                data_min = img_3d.min(),
                                data_max = img_3d.max())
# Training dataset: random 3D crops with a per-crop rate/probability map and an
# estimated background, held on the GPU. num_iter * bs samples make up one run.
ds = DecodeDataset(path = cfg.data_path.image_path,
                   dataset_tfms = [rand_crop],
                   rate_transform = probmap_generator,
                   bg_transform = estimate_backg,
                   device='cuda:0',
                   num_iter=cfg.dataloader.num_iter * cfg.dataloader.bs)
# Use the configured batch size: it was hard-coded to 2, which disagreed with
# the num_iter * cfg.dataloader.bs dataset sizing above (looked like a leftover
# debug value). num_workers=0 because the samples already live on the GPU.
decode_dl = DataLoader(ds, batch_size=cfg.dataloader.bs, num_workers=0)
# Per-image normalization statistics used to scale the network's input.
inp_offset, inp_scale = get_forward_scaling(img_3d)
# Differentiable microscope forward model (parametric PSF + noise model) on GPU.
micro = Microscope(parametric_psf=[psf], noise=noise, multipl=cfg.microscope.multipl).cuda()
# nn.Module.to() moves parameters in place, so the return value need not be
# kept. (Fixed the stray space in the original `psf .to('cuda')`.)
psf.to('cuda')
# NOTE(review): redundant with .cuda() above, but harmless — kept as-is.
micro.to('cuda')
# Sanity check: visualize projections of the current PSF volume before training.
plot_3d_projections(psf.psf_volume[0])
# Optional evaluation settings: resolve the crop-slice expression and pixel size.
if cfg.evaluation is not None:
    eval_dict = dict(cfg.evaluation)
    # SECURITY NOTE(review): eval() on a config-supplied string. Builtins are
    # stripped and only numpy's slice helper `s_` is exposed, which limits the
    # surface, but this is still unsafe if the YAML config is untrusted.
    eval_dict['crop_sl'] = eval(eval_dict['crop_sl'],{'__builtins__': None},{'s_': np.s_})
    # Materialize the OmegaConf list into a plain Python list.
    eval_dict['px_size'] = list(eval_dict['px_size'])
else:
    eval_dict = None
# Create the output directory and persist the resolved config next to the
# training artifacts, so the run can be reproduced later.
save_dir = Path(cfg.output.save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
# Reuse the Path built above instead of re-concatenating strings with '+'.
OmegaConf.save(cfg, save_dir / 'train.yaml')
# Resume from the supervised-training checkpoints: detection network,
# microscope state, and optimizer state live under cfg.output.save_dir.
model_sl = load_model_state(cfg, 'model_sl.pkl').cuda()
# NOTE(review): torch.load without map_location assumes the checkpoints were
# saved from a CUDA run and that a GPU is available here — confirm.
micro.load_state_dict(torch.load(Path(cfg.output.save_dir)/'microscope_sl.pkl'))
opt_sl = AdamW(model_sl.parameters(), lr=cfg.supervised.lr)
opt_sl.load_state_dict(torch.load(Path(cfg.output.save_dir)/'opt_sl.pkl'))
# Halve the network learning rate every 1000 steps.
scheduler_sl = torch.optim.lr_scheduler.StepLR(opt_sl, step_size=1000, gamma=0.5)
# The autoencoder phase additionally optimizes the forward model: microscope
# and PSF parameters are trained jointly with the network.
ae_param = list(micro.parameters()) + list(psf.parameters()) + list(model_sl.parameters())
opt_ae = AdamW(ae_param, lr=1e-4)
scheduler_ae = torch.optim.lr_scheduler.StepLR(opt_ae, step_size=1000, gamma=0.5)
# Qualitative check before training: run the supervised model on the evaluation
# image, re-render it through the microscope, and plot against ground truth.
# NOTE(review): load_from_eval_dict is called even when eval_dict is None, and
# eval_dict['px_size'] is indexed unconditionally below — verify that
# cfg.evaluation is always set when this page/script is executed.
gt_img, gt_df = load_from_eval_dict(eval_dict)
with torch.no_grad():
    res_gt = model_sl(gt_img[None].cuda())
    # Convert detections above the 0.1 threshold into microscope inputs
    # (locations, sub-pixel offsets, intensities) and re-render the image.
    locs_ae, x_os_ae, y_os_ae, z_os_ae, ints_ae, output_shape_ae = model_output_to_micro_input(res_gt, threshold=0.1)
    ae_img = micro(locs_ae, x_os_ae, y_os_ae, z_os_ae, ints_ae, output_shape_ae)
    # Same threshold as above, converted to a DataFrame in physical units.
    pred_gt_df = model_output_to_df(res_gt, 0.1, px_size=eval_dict['px_size'])
free_mem()
# Compare: ground-truth image, predicted emitters, and the model's
# reconstruction (rendered image plus predicted background).
gt_fig = gt_plot(gt_img, pred_gt_df, gt_df, eval_dict['px_size'],ae_img[0]+res_gt['background'][0])
plt.show()
# Autoencoder training loop: jointly fine-tunes the detection network
# (opt_sl/scheduler_sl) and the forward model — microscope + PSF —
# (opt_ae/scheduler_ae) on the unlabeled random crops from decode_dl.
train_ae(model=model_sl,
         dl=decode_dl,
         num_iter=cfg.autoencoder.num_iter,
         optim_net=opt_sl,
         optim_psf=opt_ae,
         min_int=cfg.pointprocess.min_int,
         psf=psf,
         sched_net=scheduler_sl,
         sched_psf=scheduler_ae,
         microscope=micro,
         log_interval=cfg.supervised.log_interval,
         save_dir=cfg.output.save_dir,
         log_dir=cfg.output.log_dir,
         bl_loss_scale=cfg.supervised.bl_loss_scale,
         p_quantile=cfg.supervised.p_quantile,
         grad_clip=cfg.supervised.grad_clip,
         eval_dict=eval_dict)
# Regenerate the importable library from the notebooks (nbdev shell magic).
!nbdev_build_lib